import torch
import torchvision
import torch.nn as nn
from torchvision import datasets
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader
from torch import optim
from torch.autograd import Variable
import numpy as np
from Trace import LossScaledTrace
from CreateCNN import CreateCNN
from CreateResNet import CreateResNet
from DataSet import Dataset
#from main import train_data
import time
import copy
import sys

def ComputeTraces(epoch, ModelArchitecture, *, folder='GD04', train_size=50000, B=75):
    """Compute trace statistics for a saved checkpoint and write them to disk.

    Loads the state dict saved at ``Saved01/<folder>/epoch<epoch>.pt`` into a
    freshly built model, passes it to ``LossScaledTrace`` together with the
    training data, and writes the epoch followed by the three values returned
    by ``LossScaledTrace`` (one per line) to
    ``Saved01/<folder>/Traces/epoch<epoch>.trace``.

    Args:
        epoch: Epoch identifier used in the checkpoint and output file names
            (the script entry point passes the raw command-line string).
        ModelArchitecture: ``'ResNet'`` or ``'CNN'``; selects the model builder.
        folder: Sub-directory of ``Saved01`` holding the checkpoints.
        train_size: Passed to both ``Dataset`` and ``LossScaledTrace``.
        B: Forwarded unchanged to ``LossScaledTrace`` (presumably a batch
            size -- confirm against ``Trace.LossScaledTrace``).

    Raises:
        ValueError: If ``ModelArchitecture`` is not a supported name. This is
            checked before any data or checkpoint is loaded.
    """
    import os

    # Pick the builder up front so an unknown name fails fast instead of
    # surfacing later as an UnboundLocalError on ``model``. The builder is only
    # *called* after the dataset is created, preserving the original order of
    # Dataset/model construction (either may consume the global RNG).
    if ModelArchitecture == 'ResNet':
        build_model = CreateResNet
    elif ModelArchitecture == 'CNN':
        build_model = CreateCNN
    else:
        raise ValueError(
            "Unknown ModelArchitecture {!r}; expected 'ResNet' or 'CNN'".format(
                ModelArchitecture))

    checkpoint_path = 'Saved01/{}/epoch{}.pt'.format(folder, epoch)
    train_data, _ = Dataset(train_size)  # second element is not used here
    model = build_model()
    model.load_state_dict(torch.load(checkpoint_path))
    # NOTE(review): model.eval() is intentionally not called (it was commented
    # out in the original) -- confirm the traces should be taken in the
    # model's default mode.

    # Only trainable parameters are counted; this is passed as ``d`` below.
    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

    ProductTrace, Frobenius, HessianTrace = LossScaledTrace(test_model=model,
                                                            train_data=train_data,
                                                            d=total_params,
                                                            train_size=train_size,
                                                            B=B)

    trace_dir = 'Saved01/{}/Traces'.format(folder)
    # Ensure the output folder exists so the (potentially long) computation
    # above is not lost to a missing directory at write time.
    os.makedirs(trace_dir, exist_ok=True)
    with open('{}/epoch{}.trace'.format(trace_dir, epoch), 'w') as out:
        for value in (epoch, ProductTrace, Frobenius, HessianTrace):
            out.write('{}\n'.format(value))

if __name__ == '__main__':
    # Command-line entry point: <script> EPOCH ARCHITECTURE
    # Both arguments are read as raw strings; EPOCH is only used in file names.
    if len(sys.argv) < 3:
        sys.exit('usage: {} EPOCH {{ResNet,CNN}}'.format(sys.argv[0]))
    # Seed the global torch RNG before any data/model construction happens.
    torch.manual_seed(12)
    epoch = sys.argv[1]
    ModelArchitecture = sys.argv[2]
    ComputeTraces(epoch, ModelArchitecture)